#include <stdio.h>
#include "matrix.h"

/* This file contains functions for computing the inverse of a square matrix:
 * 	A function for lower triangular matrix
 * 	A function for upper triangular matrix
 * 	A function for permutation matrix
 * 	A function for general matrix
 * 	A function for computing LUP decomposition of a matrix
 */



/* Computes the LUP decomposition of the size-n matrix stored in array M
 * using recursion, and returns the determinant of M. If M is invertible,
 * the decomposition gives three matrices L, U and P such that
 * M = L * U * P, where L is a lower triangular matrix with all 1's on its
 * diagonal, U is an upper triangular matrix, and P is a permutation matrix.
 * If the determinant turns out to be zero, 0 is returned and L, U and P
 * may be left untouched or only partially written.
 *
 * The algorithm is as follows. First M is copied into A.
 * Then a permutation matrix tildeP (the identity, or a single swap of
 * column 0 with another column) is chosen such that A = B * tildeP and the
 * first element of the first column of B is non-zero.
 * After that B is split into tildeL * C, where tildeL is a lower triangular
 * matrix with all 1's on its diagonal whose only other non-zero entries lie
 * in its first column, and C has all elements of its first column zero
 * except the first one.
 * (A is modified in place to become B and then C in the program.)
 * C is then copied into Cprime after dropping its first row and column.
 * Cprime is recursively decomposed into Lprime * Uprime * Pprime.
 * Matrix L is the product of tildeL and Lprime (extended to size n); given
 * the shape of tildeL this is the extended Lprime with the first column of
 * tildeL placed below the diagonal.
 * Matrix U is Uprime (extended to size n) with first row
 * (C[0][0], C[0][1..n-1] * Pprime-transpose).
 * Matrix P is the product of Pprime (extended to size n) and tildeP, i.e.
 * the extended Pprime with the same two columns swapped.
 * Since det(tildeL) = 1, the determinant of M is
 * C[0][0] * det(Cprime) * det(tildeP), where det(tildeP) is -1 if a column
 * swap was made and 1 otherwise.
 *
 * NOTE: the pivot is only changed when C[0][0] is exactly zero (there is no
 * partial pivoting), so tiny non-zero pivots can make the result
 * numerically inaccurate.
 *
 * Parameters:
 *   M       - input matrix (read only; it is copied before being modified)
 *   n       - number of rows/columns of M; assumed to be at most N
 *   L, U, P - output matrices of size n
 * Returns the determinant of M, or 0 if M is singular. For n <= 0 the
 * determinant of the empty matrix, 1, is returned and nothing is written.
 */
float LUP_decompose(float M[][N], int n, float L[][N], float U[][N], float P[][N])
{
	float det_value; // determinant of Cprime, from the recursive call
	float det_tildeP = 1.0f; // determinant of tildeP: -1 after a column swap, else 1
	int swap_col = 0; // column swapped with column 0; 0 means no swap was made
	float A[N][N]; // matrix initialized to M, turned into B and then C
	float tildeL_col[N]; // tildeL_col[t] == tildeL[t][0], the elimination factor of row t
	float Cprime[N][N]; // equals C minus the first row and column
	float Lprime[N][N]; // Cprime = Lprime * Uprime * Pprime
	float Uprime[N][N];
	float Pprime[N][N];

	if (n <= 0) // empty matrix: do not read A[0][0] or recurse with a negative size
		return 1.0f;

	if (n == 1) { // 1 x 1 matrix
		L[0][0] = 1;
		U[0][0] = M[0][0];
		P[0][0] = 1;
		return M[0][0];
	}

	copy_matrix(A, 0, M, 0, n); // copy M to A

	if (A[0][0] == 0) { // first element of first column zero: find a column with non-zero first element and swap
		int i = find_nonzero_column(A, n);
		if (i < 1 || i >= n) // no usable non-zero column: the first row is all zero
			return 0; // determinant is 0
		swap_column(A, 0, i, n); // swap column #i with column #0; "new A" equals B, and "old A" = B * tildeP
		swap_col = i;
		det_tildeP = -1.0f; // tildeP is a single transposition
	}

	/* Eliminate the first column of A below its first element, recording
	 * the factors (the first column of tildeL). After this loop A holds C,
	 * and B = tildeL * C.
	 */
	for (int t = 1; t < n; t++) {
		tildeL_col[t] = A[t][0] / A[0][0];
		add_row(A[t], A[0], -tildeL_col[t], n); // subtract tildeL_col[t] times row 0 from row t
	}

	// Drop the first row and column of C
	copy_matrix(Cprime, 0, A, 1, n-1);

	// Recursively decompose Cprime = Lprime * Uprime * Pprime
	det_value = LUP_decompose(Cprime, n-1, Lprime, Uprime, Pprime);
	if (is_zero(det_value)) // determinant is zero, so no inverse exists
		return 0.0f;

	/* L = tildeL * (Lprime extended to size n). tildeL differs from the
	 * identity only in its first column, and the extended Lprime has
	 * (1, 0, ..., 0) as its first row and first column, so the product is
	 * the extended Lprime with tildeL's first column below the diagonal.
	 * Building it directly avoids a matrix product that writes over one
	 * of its own inputs.
	 */
	set_to_identity(L, n);
	copy_matrix(L, 1, Lprime, 0, n-1);
	for (int t = 1; t < n; t++)
		L[t][0] = tildeL_col[t];

	/* U = Uprime extended to size n, with first row
	 * (C[0][0], C[0][1..n-1] * Pprime-transpose), so that
	 * C = (extended Lprime) * U * (extended Pprime).
	 */
	set_to_identity(U, n);
	U[0][0] = A[0][0];
	for (int k = 1; k < n; k++) {
		float sum = 0.0f;
		for (int t = 1; t < n; t++)
			sum += A[0][t] * Pprime[k-1][t-1];
		U[0][k] = sum;
	}
	copy_matrix(U, 1, Uprime, 0, n-1);

	/* P = (Pprime extended to size n) * tildeP. Right-multiplying by the
	 * column swap tildeP swaps the same two columns, so apply it directly.
	 */
	set_to_identity(P, n);
	copy_matrix(P, 1, Pprime, 0, n-1);
	if (swap_col != 0)
		swap_column(P, 0, swap_col, n);

	/* det(M) = det(tildeL) * det(C) * det(tildeP)
	 *        = 1 * (C[0][0] * det(Cprime)) * det(tildeP).
	 * det(P) must not be used here: it also contains det(Pprime), which is
	 * already part of det_value, and would flip the sign whenever Pprime is odd.
	 */
	return A[0][0] * det_value * det_tildeP;
}


/* Inverts the given lower triangular matrix L of size n
 * and stores the inverse in invL.
 * The inverse of a lower triangular matrix is itself lower triangular,
 * so every entry of invL above the diagonal is explicitly set to 0
 * (callers may pass an uninitialized array and multiply with the full
 * n x n result).
 *
 * Returns the determinant of L, i.e. the product of its diagonal elements.
 * If a diagonal element is zero, L is not invertible: 0 is returned and
 * invL is left unmodified.
 */
float inv_lower(float L[][N], float invL[][N], int n)
{
	float det = 1.0f; // stores the determinant of the matrix

	/* Check every diagonal element before writing anything. Testing the
	 * elements rather than the running product avoids a false "singular"
	 * result when the product underflows to 0, and means no division by
	 * zero can happen below.
	 */
	for (int i = 0; i < n; i++) {
		if (L[i][i] == 0.0f) // matrix not invertible
			return 0.0f;
		det = det * L[i][i];
	}

	for (int i = 0; i < n; i++) {
		// Entries above the diagonal of the inverse are zero.
		for (int k = i + 1; k < n; k++)
			invL[i][k] = 0.0f;

		// invL[i][i] * L[i][i] = 1
		invL[i][i] = 1.0f / L[i][i];

		/* For k < i, solve \sum_{j=k}^i invL[i][j] * L[j][k] = 0 for
		 * invL[i][k]. Going from k = i-1 down to 0 means every invL[i][j]
		 * with j > k is already known.
		 */
		for (int k = i - 1; k >= 0; k--) {
			float sum = 0.0f;
			for (int j = i; j > k; j--)
				sum = sum - invL[i][j] * L[j][k];
			invL[i][k] = sum / L[k][k]; // final division
		}
	}

	return det;
}


/* Inverts the given upper triangular matrix U of size n
 * and stores the inverse in invU.
 * The inverse of an upper triangular matrix is itself upper triangular,
 * so every entry of invU below the diagonal is explicitly set to 0
 * (callers may pass an uninitialized array and multiply with the full
 * n x n result).
 *
 * Returns the determinant of U, i.e. the product of its diagonal elements.
 * If a diagonal element is zero, U is not invertible: 0 is returned and
 * invU is left unmodified.
 */
float inv_upper(float U[][N], float invU[][N], int n)
{
	float det = 1.0f; // stores the determinant of the matrix

	/* Check every diagonal element before writing anything. Row i divides
	 * by U[k][k] for k > i, so checking only up to the current row (as a
	 * running-product test would) is not enough to rule out division by zero.
	 */
	for (int i = 0; i < n; i++) {
		if (U[i][i] == 0.0f) // matrix not invertible
			return 0.0f;
		det = det * U[i][i];
	}

	for (int i = 0; i < n; i++) {
		// Entries below the diagonal of the inverse are zero.
		for (int k = 0; k < i; k++)
			invU[i][k] = 0.0f;

		// invU[i][i] * U[i][i] = 1
		invU[i][i] = 1.0f / U[i][i];

		/* For k > i, solve \sum_{j=i}^k invU[i][j] * U[j][k] = 0 for
		 * invU[i][k]. Going from k = i+1 upwards means every invU[i][j]
		 * with j < k is already known.
		 */
		for (int k = i + 1; k < n; k++) {
			float sum = 0.0f;
			for (int j = i; j < k; j++)
				sum = sum - invU[i][j] * U[j][k];
			invU[i][k] = sum / U[k][k]; // final division
		}
	}

	return det;
}


/* Writes the inverse of the size-n permutation matrix P into invP.
 * A permutation matrix is orthogonal, so its inverse is simply its
 * transpose: invP[row][col] = P[col][row].
 * P and invP should be distinct arrays; transposing in place this way
 * overwrites entries before they are read.
 */
void inv_permutation(float P[][N], float invP[][N], int n)
{
	int row = 0;

	while (row < n) {
		float *dst = invP[row]; // row of the result being filled

		for (int col = 0; col < n; ++col)
			dst[col] = P[col][row];
		++row;
	}
}


/* Computes the inverse of size n matrix A and stores it in invA.
 * A is factored as A = L * U * P (see LUP_decompose), so
 * A^-1 = P^-1 * U^-1 * L^-1.
 *
 * Returns the determinant of A. If it is zero, A has no inverse and
 * invA is left unmodified.
 */
float inv_matrix(float A[][N], float invA[][N], int n)
{
	float det;
	float L[N][N]; // lower triangular
	float U[N][N]; // upper triangular
	float P[N][N]; // permutation
	/* The triangular inverses are zero-initialised so that the entries
	 * outside their triangle are 0 regardless of whether inv_lower /
	 * inv_upper write them; the full n x n products below read every entry.
	 */
	float invL[N][N] = {{0}}; // inverse of L
	float invU[N][N] = {{0}}; // inverse of U
	float invP[N][N]; // inverse of P
	float invUL[N][N]; // invU * invL, kept separate so no product writes over its own input

	det = LUP_decompose(A, n, L, U, P);
	if (is_zero(det)) // singular matrix: no inverse exists
		return det;

	// Since det != 0 here, the diagonals of L and U are expected to be non-zero.
	inv_lower(L, invL, n);
	inv_upper(U, invU, n);
	inv_permutation(P, invP, n);

	multiply_matrix(invU, invL, n, invUL);
	multiply_matrix(invP, invUL, n, invA);

	return det;
}



